In [15]:
import requests
from io import BytesIO
import numpy as np
import cv2,math
from PIL import Image
from tqdm import tqdm,trange
import matplotlib.pyplot as plt
import torch
import torch.nn.functional as F
from torchvision import models, transforms

from skimage.data import astronaut
from skimage.color import rgb2gray
from skimage.filters import sobel
from skimage.segmentation import slic,mark_boundaries
from skimage.util import img_as_float
from sklearn.utils import check_random_state


def mfpp(model,
         input,
         img_file,
         target=None,
         seed=0,
         num_masks=20000,
         resize_offset=2.0,
         layer=5,
         batch_size=32,
         p_1=0.5,
         resize_mode='bilinear'):
    r"""MFPP.

    Args:
        model (:class:`torch.nn.Module`): a model.
        input (:class:`torch.Tensor`): input tensor.
        seed (int, optional): manual seed used to generate random numbers.
            Default: ``0``.
        num_masks (int, optional): number of MFPP random masks to use.
            Default: ``8000``.
        resize_offset(float): the offset for resized image for crop. Default: ``2.5``.
        layer (int): the number of segments style. Default: ``5``.
        batch_size (int, optional): batch size to use. Default: ``128``.
        p_1 (float, optional): with prob p_1, a low-res cell is set to 1;
            otherwise, it's 1. Default: ``0.5``.
        resize_mode (str, optional): If resize is not None, use this mode for
            the resize function. Default: ``'bilinear'``.

    Returns:
        :class:`torch.Tensor`: MFPP saliency map.
    """
    SEG_COEFF=50 #relationship between segmente number and index. It's fixed to 50.
#     response = requests.get(url)
# #     print("response",response)
#     img = Image.open(BytesIO(response.content))
    img = Image.open(img_file)
    img=img.resize((224,224))

    IMAGE_SIZE = (224, 224, 3)

    segments=[[],[]]*5
    n_features=[[]]*5  
    
#     plt.figure(figsize=(10,40),dpi=120) 
    for i in range(0,layer):
        num_segments=SEG_COEFF*(2**i)
        print("n_segments:",num_segments)
        segments[i]=slic(img,n_segments=num_segments,compactness=10,sigma=1)
        n_features[i] = np.unique(np.asarray(segments[i])).shape[0]
        print ("n_features: ",n_features[i] )

    with torch.no_grad():
        # Get device of input (i.e., GPU).
        dev = input.device
        # Initialize saliency mask and mask normalization term.
        input_shape = input.shape
        saliency_shape = list(input_shape)
        print(input_shape)

        height = input_shape[2]
        width = input_shape[3]
        
        H=height + math.floor(resize_offset*height)
        W=width + math.floor(resize_offset*width)

        out = model(input)
        num_classes = out.shape[1]

        saliency_shape[1] = num_classes
        saliency = torch.zeros(saliency_shape, device=dev)

        # Save current random number generator state.
        state = torch.get_rng_state()

        # Set seed.
        torch.manual_seed(seed)

        num_chunks = (num_masks + batch_size - 1) // batch_size
#         print("num_chunks:",num_chunks)
        layer_group=num_chunks//layer
#         print("layer_group:",layer_group)
        
        for chunk in trange(num_chunks):
#             print("chunk:",chunk)
            # Generate MFPP random masks on the fly.
            mask_bs = min(num_masks - batch_size * chunk, batch_size)
#             print("mask_bs ",mask_bs) #32
            layer_index=chunk//layer_group
#             print("layer_index:",layer_index) #0~4
            np_masks = np.zeros((mask_bs, ) + IMAGE_SIZE[:2], dtype=np.float32)
#             print("init_masks.shape:",np_masks.shape) # (32, 224, 224)
            data= np.random.choice([0, 1], size=n_features[layer_index], p=[1 - p_1, p_1])
#                 print("data:",data)
            zeros = np.where(data == 0)[0]
#                 print("zeros:",zeros)
            mask = np.zeros(segments[layer_index].shape).astype(float)
            for z in zeros:
#               print("z:",z)
                mask[segments[layer_index] == z] = 1.0
#               print("mask:",mask)
            mask= Image.fromarray(mask * 255.)
#               plt.imshow(mask)

            mask = mask.resize((H,W),Image.BILINEAR)
            mask= np.array(mask)
        
            for i in range(mask_bs):#(32 masks)
                # crop to HxW
                w_crop = np.random.randint(0, resize_offset*width + 1)
                h_crop = np.random.randint(0, resize_offset*height + 1)
#                 print("w_crop h_crop:",w_crop,h_crop)
                np_masks[i] = mask[h_crop:height + h_crop, w_crop:width + w_crop]
#                 print("mask.shape:",mask.shape)
#                 print("{}:{} {}:{}",h_crop,height + h_crop, w_crop,width + w_crop)
                if np.isnan(np.sum(np_masks[i])):
                    np_masks[i] = np_masks[0].copy() if not np.isnan(np.sum(np_masks[0])) else np_masks[1].copy()
#                 np_masks[i] /= np.max(np_masks[i])
                np_masks[i] /= 255.0

            masks = torch.from_numpy(np_masks)
            masks=masks.to(dev)
            masks = masks.resize(32,1,224,224)

            # Accumulate saliency mask.
            for i, inp in enumerate(input):
                out = torch.sigmoid(model(inp.unsqueeze(0) * masks))
                if len(out.shape) == 4:
                    assert out.shape[2] == 1
                    assert out.shape[3] == 1
                    out = out[:, :, 0, 0]
                sal = torch.matmul(out.data.transpose(0, 1),
                                   masks.view(mask_bs, height * width))
                sal = sal.view((num_classes, height, width))
                saliency[i] = saliency[i] + sal
        saliency /= num_masks

        # Restore original random number generator state.
        torch.set_rng_state(state)

        return saliency

###inference
from utils import *

# url= 'https://lh3.googleusercontent.com/proxy/sDYB4cDFBemUrYfd95Y9umcDd5b_XygAGYtY4gzKWKRKGtchmOMFcepBw7SmiJ_YVkXAvTKUYUbhc5aR3G3e7YqjLIFBPluIOEtnO6sXKTKkb-L1k52VhkIf39WHrzOWikG7Oc6dPJG9y8q0mpc'
# model_name='vgg16'
model_name='resnet50'
# Obtain example data.

img_file="samples/sample.jpg"
model, x= get_example_data(img_file,model_name)

# Load label texts for ImageNet predictions so we know what model is predicting
idx2label, cls2label, cls2idx = [], {}, {}
with open(os.path.abspath('./dataset/imagenet/imagenet_class_index.json'), 'r') as read_file:
    class_idx = json.load(read_file)
    idx2label = [class_idx[str(k)][1] for k in range(len(class_idx))]
    cls2label = {class_idx[str(k)][0]: class_idx[str(k)][1] for k in range(len(class_idx))}
    cls2idx = {class_idx[str(k)][0]: k for k in range(len(class_idx))} 
    
# Predicitions we got are logits. Let's pass that through softmax to get probabilities and class labels for top 5 predictions.
logits = model(x)
probs = F.softmax(logits, dim=1)
probs5 = probs.topk(5)
ids = [idx2label[probs5[1][0][0]], idx2label[probs5[1][0][1]], idx2label[probs5[1][0][2]],
       idx2label[probs5[1][0][3]], idx2label[probs5[1][0][4]]]
tuple((p, c, idx2label[c]) for p, c in zip(probs5[0][0].cpu().detach().numpy(), probs5[1][0].cpu().detach().numpy()))
print(probs5[0][0].cpu().detach().numpy(), probs5[1][0].cpu().detach().numpy())
print(ids)


plt.figure(figsize=(30,20),dpi=80)

# Plot input image
img=Image.open(img_file)
img=img.resize((224,224))
plt.subplot(2, 3, 1)
plt.title('input image', fontsize=18)
plt.imshow(img)

# Plot topk predictions's saliency
saliency =mfpp(model, x, img_file)

for i in trange(5):
    print("i:",i)
    category_id=probs5[1][0][i].cpu().detach().numpy()
    category_name=ids[i]

    heatmap= saliency[:, category_id].unsqueeze(0)
    heatmap=heatmap.cpu().data.numpy().squeeze()
    plt.subplot(2,3,i+2)
    plt.title('MFPP for category {} ({})'.format(category_name, category_id), fontsize=18)
    plt.imshow(heatmap)



FileNotFoundError: [Errno 2] No such file or directory: '/home/qing/wind/MFPP/dataset/imagenet/imagenet_class_index.json'