In [1]:
# Enable IPython's autoreload extension: with mode 2, modified modules (such
# as the local `zest` package imported below) are re-imported before each cell
# runs, so edits to them are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

# LIME signature for remote model

We have to adapt the LIME signature computation to work with a model hosted on an MLaaS platform, which we can query but cannot load locally.
First, we generate all the perturbed points that LIME needs to fit its local models.

Then, we convert every perturbed vector back into an image, save it as a JPEG file, and upload the files to the MLaaS platform. (JPEG is lossy, so the uploaded images are close to, but not bit-identical with, the perturbed points.)
Once that is done, we can run a single batch-classification job and obtain the model output for each point.

In [2]:
import os

# Work from the parent of the directory the notebook was launched from, where
# the local `zest` package and the relative `data/` path used below are
# expected to live. The launch directory is captured only on the first run,
# so re-running this cell no longer climbs one more level each time.
try:
    NOTEBOOK_DIR
except NameError:
    NOTEBOOK_DIR = os.getcwd()
os.chdir(os.path.join(NOTEBOOK_DIR, os.pardir))

In [3]:
import json
import torch
import torchvision
import numpy as np
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt
from tqdm.auto import tqdm
from torchvision import transforms

from zest import utils
from zest import model
from zest import lime_pytorch

In [4]:
# Per-channel statistics used to normalize images before they reach the model.
NORM_MEAN = [0.485, 0.456, 0.406]
NORM_STD = [0.229, 0.224, 0.225]

normalize = transforms.Normalize(mean=NORM_MEAN, std=NORM_STD)

# Exact inverse of `normalize`. Normalize computes (z - m') / s'; with
# m' = -mean / std and s' = 1 / std, this equals z * std + mean.
# Deriving it from the same constants keeps both transforms in sync. The
# previous hand-written version used 0.255 instead of 0.225 for the third
# channel, which restored that channel incorrectly.
inv_normalize = transforms.Normalize(
    mean=[-m / s for m, s in zip(NORM_MEAN, NORM_STD)],
    std=[1 / s for s in NORM_STD],
)

transform = transforms.Compose([transforms.ToTensor(), normalize])

def show_torch_float(i):
    """Draw an image tensor with matplotlib.

    The tensor is cast to ``torch.uint8`` as-is (no scaling, no clamping), so
    it presumably already holds pixel intensities in [0, 255]. The cast tensor
    is handed to ``to_pil_image``, which takes a C x H x W (or H x W) tensor.
    """
    as_uint8 = i.to(torch.uint8)
    pil_img = transforms.functional.to_pil_image(as_uint8)
    plt.imshow(pil_img)
    plt.show()

def show_torch_advex(img):
    """Undo the input normalization of a C x H x W tensor and display it.

    The tensor is passed through the module-level ``inv_normalize`` transform.
    It is then moved to channels-last (H x W x C) layout, which
    ``plt.imshow`` expects for RGB data.
    """
    channels_last = inv_normalize(img).permute(1, 2, 0)
    plt.imshow(channels_last.numpy())
    plt.show()

In [5]:
# Experiment configuration.
dataset = 'CIFAR10'
batch_size = 32
# Presumably the distance metrics (L1, L2, L-inf, cosine) compared later.
# TODO confirm.
dist = ['1', '2', 'inf', 'cos']

# Name passed as `save_name` to the zest.lime_pytorch helpers below.
lime_data_name = '{}_{}_lime'.format(dataset, batch_size)
save_name = lime_data_name

We need to reproduce the behavior of the `lime()` function.

```python
# Original LIME function from Zest 
def lime(self, save_name=None, cat=True):
    if save_name is None:
        save_name = self.lime_data_name
    self.net.eval()
    if self.lime_data is None:
        self.lime_data = lime_pytorch.prepare_lime_ref_data(save_name, self.trainset, self.batch_size)
    if self.lime_segment is None:
        self.lime_segment = lime_pytorch.prepare_lime_segment(save_name, self.lime_data, self.trainset)
    if self.ref_dataset is None or self.lime_dataset is None:
        self.ref_dataset, self.lime_dataset = lime_pytorch.prepare_lime_dataset(save_name, self.lime_data,
                                                                                self.lime_segment)
    self.lime_mask = lime_pytorch.compute_lime_signature(self.net, self.ref_dataset, self.lime_dataset, cat=cat)
    self.net.train()
```

In [6]:
# Load the `dataset` images used as LIME references. The second positional
# argument (True) is presumably the train/test flag; TODO confirm against
# zest.utils.load_dataset.
trainset = utils.load_dataset(dataset, True, download=True)
# Reference images for LIME. In this run lime_data has shape (32, 32, 32, 3),
# i.e. its first dimension equals `batch_size`.
lime_data = lime_pytorch.prepare_lime_ref_data(save_name, trainset, batch_size)

Files already downloaded and verified


In [7]:
# One 2-D map per reference image (shape (32, 32, 32) in this run). These are
# presumably superpixel labels used to build the LIME perturbations; TODO
# confirm against zest.lime_pytorch.
lime_segment = lime_pytorch.prepare_lime_segment(save_name, lime_data, trainset)

In [8]:
# Sanity-check the shapes of the reference images and their segmentations.
for arr in (lime_data, lime_segment):
    print(arr.shape)

(32, 32, 32, 3)
(32, 32, 32)


In [9]:
# Build the perturbed samples LIME needs. In this run ref_dataset comes out as
# (32, 1000, 32, 32, 3): 1000 perturbed channels-last images per reference
# image. lime_dataset holds 32 arrays of shape (1000, 23): one row per
# perturbed sample, presumably superpixel on/off indicators (TODO confirm).
ref_dataset, lime_dataset = lime_pytorch.prepare_lime_dataset(save_name, lime_data, lime_segment)

In [10]:
print(ref_dataset.shape)
print(lime_dataset.shape)
print(lime_dataset[0].shape)

# label_lime_dataset() pairs lime_dataset[i] with ref_dataset[i] and runs the
# model on every perturbed image of ref_dataset[i]. Each row of
# lime_dataset[i] must therefore correspond to one perturbed image. Check this
# before exporting tens of thousands of files for the remote model.
assert len(lime_dataset) == len(ref_dataset), \
    'expected one LIME sample matrix per reference image'
assert all(len(samples) == ref_dataset.shape[1] for samples in lime_dataset), \
    'expected one LIME sample row per perturbed image'

(32, 1000, 32, 32, 3)
(32,)
(1000, 23)


We can't call `compute_lime_signature()` directly because it internally calls `label_lime_dataset()`, which needs a local PyTorch model (it calls `model.parameters()` and `model(inputs)`). So we adapt that labelling step to the remote model.


```python
# Original label data function from Zest
def label_lime_dataset(lime_dataset, ref_dataset, model):
    device = torch.device('cuda:0' if next(model.parameters()).is_cuda else 'cpu')
    datasets = []
    with torch.no_grad():
        for i in range(len(lime_dataset)):
            lime_data = lime_dataset[i]
            data = ref_dataset[i]
            inputs = torch.from_numpy(data).to(device).permute(0, 3, 1, 2).float()
            outputs = model(inputs).detach().cpu().numpy()
            datasets.append([lime_data, outputs])
    return datasets
```

We need the model outputs for each sample in the `ref_dataset` array.

In [11]:
# Where the exported LIME samples go:
# - a local directory for the JPEGs,
# - a JSON Lines manifest (one line per image),
# - the GCS prefix the JPEGs are expected to be uploaded to.
# The roots can be overridden through environment variables; the defaults are
# the original hard-coded locations.
local_root = os.environ.get('LIME_LOCAL_ROOT', '/net/data/fedpois')
remote_root = os.environ.get('LIME_REMOTE_ROOT', 'gs://bad-lemon-vcm')

base_save_path = os.path.join(local_root, 'lime{}'.format(batch_size))
os.makedirs(base_save_path, exist_ok=True)

jsonl_file = os.path.join('data', 'lime{}.jsonl'.format(batch_size))
# The manifest is opened for writing in the export step; create its
# (relative) parent directory now so that step cannot fail with
# FileNotFoundError on a fresh checkout.
os.makedirs(os.path.dirname(jsonl_file), exist_ok=True)

# GCS object names always use '/', so build this with string formatting
# rather than os.path.join.
remote_base_pth = '{}/lime{}/'.format(remote_root.rstrip('/'), batch_size)

print(base_save_path)
print(remote_base_pth)
print(jsonl_file)

# Manifest line format, e.g.
# {"content": "gs://sourcebucket/datasets/images/source_image.jpg", "mimeType": "image/jpeg"}

/net/data/fedpois/lime32
gs://bad-lemon-vcm/lime32/
data/lime32.jsonl


**Warning**: This creates a **lot** of JPEG files: one per perturbed sample (32 × 1000 = 32,000 in this configuration). Make sure you have set up the paths correctly, and only execute this cell once.

In [12]:
# Export every perturbed sample of every reference image as a JPEG, and record
# its (future) GCS location in a JSON Lines manifest for batch prediction.
#
# ref_dataset is laid out as (base image, perturbation, H, W, C). torchvision
# expects channels-first images, so permute once up front instead of
# converting the whole array again on every outer iteration.
ref_tensor = torch.from_numpy(ref_dataset).permute(0, 1, 4, 2, 3)

with open(jsonl_file, 'w') as mf:
    for base_img, cur_base in enumerate(tqdm(ref_tensor, desc='exporting LIME samples')):
        for pert_idx, pert_img in enumerate(cur_base):
            # Cast to float before undoing the normalization, because
            # Normalize rejects integer tensors. Then clamp: to_pil_image
            # scales float tensors by 255 and casts to uint8, so any value
            # outside [0, 1] would wrap around instead of saturating.
            unnorm = inv_normalize(pert_img.float()).clamp(0.0, 1.0)
            unnorm_pil = transforms.functional.to_pil_image(unnorm)

            file_name = f'{base_img}_{pert_idx}.jpg'
            # NOTE(review): JPEG is lossy, so the image the remote model
            # classifies is close to, but not exactly, this LIME perturbation.
            unnorm_pil.save(os.path.join(base_save_path, file_name))

            line_base = {"content": remote_base_pth + file_name, "mimeType": "image/jpeg"}
            mf.write(json.dumps(line_base) + '\n')

  0%|          | 0/32 [00:00<?, ?it/s]

In [13]:
# Optional sanity check (disabled): display the first perturbation of the
# first reference image after undoing the normalization. Uncomment to run.
# plt.imshow( transforms.functional.to_pil_image( inv_normalize( torch.from_numpy(ref_dataset)[0].permute(0, 3, 1, 2)[0] ).float() ) )

In [14]:
# Optional sanity check (disabled): round-trip the same image through PIL and
# back to a tensor, to inspect the uint8 pixel values that end up in the
# exported file. Uncomment to run.
# transforms.functional.pil_to_tensor( transforms.functional.to_pil_image( inv_normalize( torch.from_numpy(ref_dataset)[0].permute(0, 3, 1, 2)[0] ).float() ) )