In [1]:
import captum
import pandas as pd
import matplotlib.pyplot as plt
import torch
import copy

from matplotlib.patches import Rectangle
from os import path
from skimage import io
from torch import optim
from torch.nn import Linear, CrossEntropyLoss
from torch.utils.data import Dataset, DataLoader, random_split
from torchvision import models

In [2]:
# Raw annotation table; the CSV's first column becomes the row index.
labels_frame = pd.read_csv(
    path.join('..', 'input', 'inz-data-prep', 'easy_labels_and_data.csv'),
    index_col=0,
)

In [3]:
def show_image_bbox(image, bboxes: pd.DataFrame):
    """Display ``image`` and outline each bounding box of ``bboxes`` in red.

    Args:
        image: Image accepted by ``plt.imshow``. When it exposes a ``shape``
            attribute (e.g. an H x W x C numpy array) the boxes are scaled by
            that height and width; otherwise a 512 x 512 image is assumed.
        bboxes: One row per box. The first four columns are read by position
            as ``(x_center, y_center, width, height)``, expressed as fractions
            of the image size.
    """
    plt.imshow(image)
    ax = plt.gca()

    # Scale by the real image size instead of a fixed 512 so boxes stay
    # aligned on images of other resolutions (identical for 512 x 512 input).
    height, width = getattr(image, 'shape', (512, 512))[:2]

    # Read the columns by position: integer keys on a label-indexed Series
    # (``bbox[0]``) rely on a positional fallback that pandas has deprecated.
    for x_center, y_center, box_w, box_h in bboxes.iloc[:, :4].itertuples(index=False, name=None):
        # Rectangle expects its anchor corner, not the box centre.
        x = (x_center - box_w / 2) * width
        y = (y_center - box_h / 2) * height

        rect = Rectangle((x, y), box_w * width, box_h * height, fill=False, edgecolor='r')
        ax.add_patch(rect)
        
class CropWeedDataset(Dataset):
    """Image-level dataset backed by a CSV of per-box annotations.

    The CSV must contain a ``filename`` column. Rows are grouped by filename,
    so the dataset holds one sample per distinct image file.
    """

    def __init__(self, labels_csv, images_dir, transform=None):
        """Read the annotation CSV and remember where the images live.

        Args:
            labels_csv: Path (or file-like object) handed to ``pd.read_csv``;
                its first column is used as the row index.
            images_dir: Directory that the filenames are joined onto.
            transform: Optional callable applied to each sample dict.
        """
        annotations = pd.read_csv(labels_csv, index_col=0)
        self.labels_frame = annotations
        # One row per image file; every cell holds the number of non-null
        # annotation values recorded for that file.
        self.grouped_labels_frame = annotations.groupby('filename').count()
        self.images_dir = images_dir
        self.transform = transform

    def __len__(self):
        """Return the number of distinct image files."""
        return len(self.grouped_labels_frame)

    def __getitem__(self, idx):
        """Load the image at position ``idx`` and build its sample dict.

        The label is 1 when the first count column of the grouped frame is
        positive for this file and 0 otherwise.
        """
        if torch.is_tensor(idx):
            idx = idx.tolist()

        per_file_counts = self.grouped_labels_frame
        # The grouped frame is indexed by filename, so the row name is the file.
        image = io.imread(path.join(self.images_dir, per_file_counts.iloc[idx].name))
        sample = {
            'image': image,
            'label': int(per_file_counts.iloc[idx, 0] > 0),
        }

        return self.transform(sample) if self.transform else sample

    def get_bbox(self, idx):
        """Return columns 2-5 of every annotation row belonging to image ``idx``."""
        filename = self.grouped_labels_frame.iloc[idx].name
        same_file = self.labels_frame['filename'] == filename
        return self.labels_frame.loc[same_file].iloc[:, 2:6]

    def show_image(self, idx):
        """Plot image ``idx`` together with its bounding boxes."""
        sample = self[idx]
        bbox = self.get_bbox(idx)

        image = sample['image']
        if self.transform:
            # A transformed image is treated as a C x H x W tensor and is
            # converted back to an H x W x C numpy array for plotting.
            image = image.numpy().transpose(1, 2, 0)

        show_image_bbox(image, bbox)

    
class ToTensor(object):
    """Callable transform that turns the ndarray-based sample into tensors."""

    def __call__(self, sample):
        """Return a new sample dict whose image and label are torch tensors."""
        hwc_image, raw_label = sample['image'], sample['label']

        # numpy stores the image as H x W x C while torch expects C x H x W,
        # so the colour axis is moved to the front before conversion.
        chw_image = hwc_image.transpose(2, 0, 1)
        scaled_image = torch.from_numpy(chw_image) / 255
        label_tensor = torch.asarray(raw_label)
        return {'image': scaled_image, 'label': label_tensor}

In [4]:
# Full dataset: one sample per image file, converted to tensors by ToTensor.
dataset = CropWeedDataset(
    labels_csv=path.join('..', 'input', 'inz-data-prep', 'easy_labels_and_data.csv'),
    images_dir=path.join('..', 'input', 'crop-and-weed-detection-data-with-bounding-boxes', 'agri_data', 'data'),
    transform=ToTensor(),
)



In [5]:
# Seeded split: 950 training samples, everything that remains is held out for
# validation. Deriving the remainder from len(dataset) keeps the split valid
# if the annotation CSV changes size (it equals 204 for the 1154-image set).
train, test = random_split(
    dataset,
    [950, len(dataset) - 950],
    generator=torch.Generator().manual_seed(420),
)
dataloaders = {
    'train': DataLoader(train, batch_size=8, shuffle=True, num_workers=2),
    'val': DataLoader(test, batch_size=8, shuffle=True, num_workers=2),
}
# Read the sizes from the actual splits so they cannot drift from the lengths above.
dataset_sizes = {'train': len(train), 'val': len(test)}


## Load model

In [6]:
# Use the first GPU when one is visible, otherwise stay on the CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# ResNet-18 whose final layer is replaced by a 2-output head, then filled with
# the saved weights. The checkpoint is always deserialised onto the CPU first.
model = models.resnet18()
model.fc = Linear(512, 2)
model.load_state_dict(
    torch.load(path.join('models', 'resnet-18'), map_location=torch.device('cpu'))
)
# Move to the chosen device and switch to evaluation mode.
model = model.to(device).eval()

In [7]:
# Sanity-check forward pass on the first sample: add the batch dimension and
# move the input to the device the model lives on (a CPU tensor fed to a CUDA
# model raises a device-mismatch error).
model(torch.unsqueeze(dataset[0]['image'], 0).to(device))

tensor([[-1.7695,  2.8448]], grad_fn=<AddmmBackward0>)

In [8]:
def baseline_func(input):
    """Return ``input`` multiplied by zero, used as the attribution baseline."""
    return 0 * input

from torchvision import transforms
# Per-channel normalisation with mean 0.5 and std 0.5 for all three channels.
normalize = transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))

In [9]:
from captum.insights import AttributionVisualizer, Batch
from captum.insights.attr_vis.features import ImageFeature

In [10]:
def formatted_data_iter():
    """Yield Captum ``Batch`` objects built from the first three dataset samples.

    The generator ends normally once the underlying DataLoader is exhausted.

    Yields:
        Batch: ``inputs`` is the image tensor of one DataLoader batch, moved to
        the module-level ``device``; ``labels`` is the matching label tensor.
    """
    dataset = torch.utils.data.Subset(
        CropWeedDataset(
            labels_csv=path.join('..', 'input', 'inz-data-prep', 'easy_labels_and_data.csv'),
            images_dir=path.join('..', 'input', 'crop-and-weed-detection-data-with-bounding-boxes', 'agri_data', 'data'),
            transform=ToTensor(),
        ),
        range(3),
    )
    dataloader = torch.utils.data.DataLoader(dataset, batch_size=4, shuffle=False, num_workers=2)

    # Iterate with ``for`` rather than ``while True: next(...)``. Calling
    # ``next`` on an exhausted iterator inside a generator is converted into
    # RuntimeError (PEP 479), whereas a finished ``for`` loop lets the
    # generator stop cleanly so the consumer sees a normal StopIteration.
    for sample_dict in dataloader:
        # The model was moved to ``device``; its inputs must live there too.
        images = sample_dict['image'].to(device)
        labels = sample_dict['label']
        yield Batch(inputs=images, labels=labels)

In [11]:
# Scratch check: wrap the full dataset in a Subset restricted to indices 0-9.
torch.utils.data.Subset(dataset, indices=range(10))

<torch.utils.data.dataset.Subset at 0x251d0487040>

In [12]:
# Captum Insights visualiser: one model, two class names, and a single image
# feature that uses baseline_func for baselines and normalize on the inputs.
visualizer = AttributionVisualizer(
    models=[model],
    classes=['crop', 'weed'],
    features=[
        ImageFeature(
            'photo',
            baseline_transforms=[baseline_func],
            input_transforms=[normalize],
        ),
    ],
    dataset=formatted_data_iter(),
)

In [13]:
# Display the Captum Insights widget in the notebook.
# NOTE(review): debug=True presumably adds the extra Output() area shown below
# the widget — confirm against the Captum documentation.
visualizer.render(debug=True)

CaptumInsights(insights_config={'classes': ['crop', 'weed'], 'methods': ['Deconvolution', 'Deep Lift', 'Guided…

Output()