In [1]:
import torch

import numpy as np
from utils import nethook, imgviz, show,tally


#https://github.com/SIDN-IAP/global-model-repr
#https://github.com/davidbau/dissect

In [None]:
# Pick the compute device once: the GPU when CUDA is usable, the CPU otherwise.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
print('Running pytorch', torch.__version__, 'using', device.type)

Get the model

In [None]:
# Fraction of activations treated as the "top" tail.
quantile=0.01
# Complementary level (1 - quantile). NOTE(review): the visible cells only
# reference percent_level in a commented-out line; the visualizers below pass
# the literal 0.99 instead -- confirm whether they should share this constant.
percent_level = 1.0 - quantile

In [None]:
import os
import torch.nn as nn
import logging
from torchvision import models

def get_model(device,base_model=None):
    '''
    Build a VGG16 whose final classifier layer is replaced by a 2-class head.

    The convolutional ``features`` weights are frozen; the classifier layers
    (including the new head) stay trainable.

    :param device: target passed to ``torch.load(map_location=...)`` when a
        checkpoint is loaded. The returned model is NOT moved to this device;
        the caller is expected to call ``model.to(device)``.
    :param base_model: optional path to a saved ``state_dict`` for the modified
        architecture. When it is falsy or does not exist on disk, the
        ImageNet-pretrained weights (with a freshly initialised head) are kept.
    :return: vgg16 model with last layer modification (2 classes)
    '''
    model = models.vgg16(pretrained=True)

    # Freeze trained weights of the convolutional backbone.
    for param in model.features.parameters():
        param.requires_grad = False

    # Newly created modules have requires_grad=True by default, so the new
    # head below is trainable without further action.
    num_features = model.classifier[6].in_features
    head = list(model.classifier.children())[:-1]  # Remove last layer
    head.append(nn.Linear(num_features, 2))  # Add our layer with 2 outputs
    model.classifier = nn.Sequential(*head)  # Replace the model classifier

    # Load pre initialized model
    if base_model and os.path.exists(base_model):
        logging.info(f'Loading {base_model}')
        model.load_state_dict(torch.load(base_model, map_location=device))
    else:
        if base_model:
            # A path was requested but is missing: make the fallback visible,
            # because the 2-class head is then untrained.
            logging.warning(
                f'Checkpoint {base_model} not found; falling back to the '
                'ImageNet-pretrained VGG16 with an untrained 2-class head'
            )
        logging.info('Loading pretrained VGG16 model')

    return model



Load the dataset 

In [None]:
def eval(model,img_tensor):
    """Classify a single image and return ``[predicted_index, score]``.

    NOTE: this name shadows Python's builtin ``eval`` for the rest of the
    notebook; it is kept because other cells call it.

    :param model: a ``torch.nn.Module``; it is switched to eval mode as a side
        effect.
    :param img_tensor: input batch holding exactly one sample (``.item()`` is
        called on the result, which fails for more than one element).
    :return: ``[predicted, score]`` where ``score`` is the largest raw model
        output along dim 1 (not a softmax probability) and ``predicted`` is
        its index.
    """
    model.eval()
    with torch.no_grad():
        # Run on whatever device the model lives on instead of assuming CUDA,
        # so the notebook's CPU fallback actually works.
        first_param = next(model.parameters(), None)
        if first_param is not None:
            img_tensor = img_tensor.to(first_param.device)
        output = model(img_tensor)
        score, predicted = torch.max(output, 1)
        out = [predicted.item(),score.item()]
    return out

#print(eval(model,test_data_loader[0][0][None])[0])

In [None]:
from typing import Tuple, List, Dict
from torchvision import datasets, transforms

class CustomImageFolder(datasets.ImageFolder):
    """``ImageFolder`` variant with reverse-alphabetical class indices.

    Differences from the stock ``ImageFolder``:

    * class folders are indexed in reverse alphabetical order;
    * ``__getitem__`` returns a third element, ``self.imgs[index]``
      (torchvision documents ``imgs`` entries as ``(path, class_index)``).
    """

    def __init__(self, dataset, transform=None):
        """
        :param dataset: root directory containing one sub-directory per class.
        :param transform: optional transform applied to each loaded image.
        """
        super().__init__(dataset, transform=transform)

    def find_classes(self, directory: str) -> Tuple[List[str], Dict[str, int]]:
        """Return class names in reverse-sorted order and their index mapping.

        Recent torchvision releases call ``find_classes`` (public name) while
        older ones call ``_find_classes``; both are defined so the reversed
        ordering applies regardless of the installed version.
        """
        classes = [d.name for d in os.scandir(directory) if d.is_dir()]
        classes.sort(reverse=True)
        class_to_idx = {cls_name: i for i, cls_name in enumerate(classes)}
        return classes, class_to_idx

    def _find_classes(self, dir: str) -> Tuple[List[str], Dict[str, int]]:
        """Legacy hook name used by older torchvision; same result as ``find_classes``."""
        return self.find_classes(dir)

    def __getitem__(self, index):
        """Return ``(sample, label, self.imgs[index])`` for the given index."""
        sample, label = super().__getitem__(index)
        return sample, label, self.imgs[index]


def load_and_transform_data(dataset, batch_size=1, data_augmentation=False):
    """Build a ``CustomImageFolder`` over ``dataset`` with the VGG-16 preprocessing.

    Images are resized to 224x224 (the input size VGG-16 expects), converted
    to tensors and normalised per channel.

    :param dataset: root directory of the image folder (one sub-folder per class).
    :param batch_size: accepted for backward compatibility; not used, because
        no ``DataLoader`` is created here.
    :param data_augmentation: accepted for backward compatibility; not used.
    :return: the ``CustomImageFolder`` dataset (indexable; not a ``DataLoader``).
    """
    logging.info(f'Loading data from {dataset}')

    # Per-channel statistics actually applied by Normalize.
    mean = [0.4387, 0.3090, 0.2211]
    std = [0.2733, 0.2035, 0.1717]

    data_transforms = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=mean, std=std),
    ])

    # Keep the path in `dataset` so the log line below reports the directory
    # rather than the dataset object's repr.
    image_dataset = CustomImageFolder(dataset, transform=data_transforms)

    logging.info(f'Loaded {len(image_dataset)} images under {dataset}: Classes: {image_dataset.class_to_idx}')

    return image_dataset

In [None]:
# get_model's first positional parameter is `device`; the checkpoint path has
# to be passed as `base_model`, otherwise the trained weights are never loaded.
model = get_model(
    device,
    base_model='/home/aharris/shared/EyePACS/models/exp9/weights_190.pth',
)
model.to(device=device)

test_image_folder = '/home/aharris/shared/EyePACS/input/image/dynamic_run/test'

# NOTE: despite its name this is an indexable dataset, not a DataLoader.
test_data_loader = load_and_transform_data(test_image_folder)

#renorm = renormalize.renormalizer(source=test_data_loader, target='zc')
# Small (56x56) visualizer used for the per-unit thumbnails below.
ivsmall = imgviz.ImageVisualizer((56, 56), source=test_data_loader, percent_level=0.99)

In [None]:
# Full-resolution (224) visualizer; display the first dataset image as a sanity check.
iv = imgviz.ImageVisualizer(
    224,
    image_size=(224, 224),
    source=test_data_loader,
    percent_level=0.99,
)
show(iv.image(test_data_loader[0][0]))

## Examine raw unit activations.

Look at individual activations


In [None]:
layername = 'features.28'
# Wrap only once so re-running this cell does not nest InstrumentedModel wrappers.
if not isinstance(model, nethook.InstrumentedModel):
    model = nethook.InstrumentedModel(model)
model.retain_layer(layername)
# Use up to 263 images, but never index past the end of the dataset.
indexes = range(0, min(263, len(test_data_loader)))
batch = torch.stack([test_data_loader[i][0] for i in indexes])
# Inference only: eval mode, no autograd graph, and the notebook's device
# (instead of a hard-coded .cuda()) so the CPU fallback works.
model.eval()
with torch.no_grad():
    model(batch.to(device))
acts = model.retained_layer(layername).cpu()
show([
    [
        [ivsmall.masked_image(batch[imagenum], acts[imagenum], unitnum)],
        [ivsmall.heatmap(acts[imagenum], unitnum, mode='nearest')],
        'unit %d' % unitnum
    ]
    for unitnum in range(acts.shape[1])
    for imagenum in [22]
])

## Examine images that maximize each unit
The cell below identifies the images, out of a sample of `sample_size` dataset images, that cause each filter to activate most strongly.

In [None]:
# Number of dataset items passed to tally.tally_topk as `sample_size` below.
sample_size = 263
def max_activations(batch, *args):
    """Per-image, per-unit maximum of the retained layer's activations.

    Runs the global ``model`` on ``batch`` (moved to the global ``device``),
    reads the activations retained for ``layername``, flattens every dimension
    after the first two and takes the maximum over that flattened axis.
    Extra positional arguments are accepted and ignored.
    """
    model(batch.to(device))  # forward pass refreshes the retained activations
    retained = model.retained_layer(layername)
    leading_dims = retained.shape[:2]
    flattened = retained.view(leading_dims + (-1,))
    return flattened.max(dim=2).values

def mean_activations(batch, *args):
    """Per-image, per-unit mean of the retained layer's activations.

    Runs the global ``model`` on ``batch`` (moved to the global ``device``),
    reads the activations retained for ``layername``, flattens every dimension
    after the first two and averages over that flattened axis.
    Extra positional arguments are accepted and ignored.
    """
    model(batch.to(device))  # forward pass refreshes the retained activations
    retained = model.retained_layer(layername)
    leading_dims = retained.shape[:2]
    flattened = retained.view(leading_dims + (-1,))
    return flattened.mean(dim=2)

# Rank dataset images per unit by their maximum activation.
# The cache file is named after the statistic being tallied (max), so a cache
# produced with a different statistic is not silently reused.
topk = tally.tally_topk(
    max_activations,
    dataset=test_data_loader,
    sample_size=sample_size,
    cachefile='results/cache_max_topk.npz'
)

# result() is evaluated once; element 0 holds the scores, element 1 the indexes.
topk_result = topk.result()
top_scores = topk_result[0]
top_indexes = topk_result[1]

In [None]:
# Object array with one [image_index, activation_score] pair for each of the
# 512 units and each of the first 30 ranks of the top-k tally.
top_array = np.empty((512, 30), dtype=object)
for unit, rank in np.ndindex(512, 30):
    image_index = top_indexes[unit][rank].item()
    activation = top_scores[unit][rank].item()
    top_array[unit, rank] = [image_index, activation]

The loop below iterates over units: for each one it runs the model on that unit's 30 top-activating images, shows where the unit activates within each image, and writes the result to an HTML file under `out_path`.

In [None]:
# Rebind `iv` to a full-size visualizer built WITHOUT an explicit level /
# percent_level argument (the two commented-out lines are the alternatives
# that were tried). NOTE(review): which threshold ImageVisualizer uses when
# none is given is not visible here -- confirm against utils.imgviz.
#iv = imgviz.ImageVisualizer(224, image_size= (224,224),source=test_data_loader, level=rq.quantiles(percent_level),quantiles=rq)
iv = imgviz.ImageVisualizer(224, image_size= (224,224),source=test_data_loader)
#iv = imgviz.ImageVisualizer(224, image_size= (224,224),source=test_data_loader, percent_level=0.99)

In [None]:
# Directory that receives one '<unit>_unit_top30.html' file per unit.
out_path = '/home/aharris/shared/EyePACS/interpretability/results_dissection'

In [None]:
# Make sure the target directory exists before writing any report.
os.makedirs(out_path, exist_ok=True)

# Units are 0-indexed (see the raw-activation cell), so start at 0 rather
# than skipping the first unit.
for u in range(512):
    rows = []
    for j in range(30):
        img_index, act_score = top_array[u][j]

        # Load and transform the image once; the dataset item is
        # (image tensor, label index, imgs entry).
        image, label, _ = test_data_loader[img_index]

        # eval() runs the instrumented model on this single image, which also
        # refreshes the activations retained for `layername`. Read them right
        # away so the mask below provably belongs to this image.
        predicted_class, _ = eval(model, image[None])
        unit_acts = model.retained_layer(layername)[0]

        rows.append([
            'unit %d' % u,
            'img %d' % img_index,
            'pred: %s' % test_data_loader.classes[predicted_class],
            'score: %f' % act_score,
            # Ground truth comes from the dataset label, not from a fixed
            # position in the file path (which depended on the directory depth).
            'GroundTruth: %s' % test_data_loader.classes[label],
            [iv.masked_image(image, unit_acts, u)],
        ])

    img = show.blocks(rows)

    html = img.data
    with open(f'{out_path}/{u}_unit_top30.html', 'w') as f:
        f.write(html)

    