# Deep Feature Factorization

In [1]:
%load_ext autoreload
%autoreload 2

import torch
import torchvision
import numpy as np
import cv2
import os
import requests
import random
from matplotlib import pyplot as plt
import seaborn as sns
from sklearn.manifold import TSNE
from PIL import Image
from utils.image_loader import get_image_from_url, get_image_from_fs

from pytorch_grad_cam import DeepFeatureFactorization
from pytorch_grad_cam.utils.image import show_cam_on_image, preprocess_image, deprocess_image

In [None]:
# Build a DenseNet-201 and load its pretrained weights.
model = torchvision.models.densenet201(pretrained=True)

# Switch the model to evaluation (inference) mode. This does NOT freeze the
# weights; it only changes the behaviour of layers such as dropout and batch
# normalisation so that they act deterministically at inference time.
model.eval()

''

In [None]:
def create_labels(concept_scores, top_k=2, labels=None):
    """Build one text label per concept from its top scoring ImageNet categories.

    :param concept_scores: 2-D array indexed as [concept, category] holding the
                           score of every category for every concept.
    :param top_k: Number of best scoring categories to list for each concept.
    :param labels: Optional lookup from category index to ImageNet label text
                   (anything indexable by the category index). When None, the
                   table is downloaded from the public gist below.
    :returns: A list with one string per concept. Each string has one
              "<first label name>:<score with two decimals>" line per listed
              category, joined with newlines, best category first.
    :raises requests.HTTPError: If the label table has to be downloaded and the
                                server answers with an error status.
    """
    if labels is None:
        # Local import: keeps this cell self-contained in the notebook.
        import ast

        imagenet_categories_url = \
            "https://gist.githubusercontent.com/yrevar/942d3a0ac09ec9e5eb3a/raw/238f720ff059c1f82f368259d1ca4ffa5dd8f9f5/imagenet1000_clsidx_to_labels.txt"
        # A timeout prevents the notebook from hanging forever on a dead
        # connection, and raise_for_status() surfaces HTTP errors directly
        # instead of as an obscure parse failure.
        response = requests.get(imagenet_categories_url, timeout=30)
        response.raise_for_status()
        # The gist is a Python dict literal. literal_eval parses literals only,
        # so unlike eval() it cannot execute code coming from the network.
        labels = ast.literal_eval(response.text)

    # Category indices of every concept sorted by descending score, cut to top_k.
    concept_categories = np.argsort(concept_scores, axis=1)[:, ::-1][:, :top_k]
    concept_labels_topk = []
    for concept_index, categories in enumerate(concept_categories):
        concept_labels = [
            # Only the first comma separated synonym of the label is shown.
            f"{labels[category].split(',')[0]}"
            f":{concept_scores[concept_index, category]:.2f}"
            for category in categories
        ]
        concept_labels_topk.append("\n".join(concept_labels))
    return concept_labels_topk

def show_factorization_on_image(img: np.ndarray,
                                explanations: np.ndarray,
                                colors: list[np.ndarray] = None,
                                image_weight: float = 0.5,
                                concept_labels: list = None) -> np.ndarray:
    """ Color code the different component heatmaps on top of the image.
        Every component color code will be magnified according to the heatmap intensity
        (by modifying the V channel in the HSV color space),
        and optionally create a legend that shows the labels.
        Since different factorization component heatmaps can overlap in principle,
        we need a strategy to decide how to deal with the overlaps.
        This keeps the component that has a higher value in its heatmap.

        The input arrays are not modified.

        Taken from https://github.com/jacobgil/pytorch-grad-cam/blob/2183a9cbc1bd5fc1d8e134b4f3318c3b6db5671f/pytorch_grad_cam/utils/image.py#L83
    :param img: The base image RGB format.
    :param explanations: A tensor of shape num_components x height x width, with the component visualizations.
    :param colors: List of R, G, B colors to be used for the components.
                   If None, will use the gist_rainbow cmap as a default.
    :param image_weight: The final result is image_weight * img + (1-image_weight) * visualization.
    :param concept_labels: A list of strings for every component. If this is passed, a legend that shows
                           the labels and their colors will be added to the right of the image.
    :returns: The visualized image as a uint8 array.
    """
    n_components = explanations.shape[0]
    if colors is None:
        # taken from https://github.com/edocollins/DFF/blob/master/utils.py
        # plt.get_cmap is used because plt.cm.get_cmap is deprecated and has
        # been removed from recent Matplotlib releases.
        _cmap = plt.get_cmap('gist_rainbow')
        colors = [np.array(_cmap(i))
                  for i in np.arange(0, 1, 1.0 / n_components)]

    # For every pixel, the index of the component with the strongest response.
    concept_per_pixel = explanations.argmax(axis=0)
    masks = []
    for i in range(n_components):
        mask = np.zeros(shape=(img.shape[0], img.shape[1], 3))
        mask[:, :, :] = colors[i][:3]
        # Keep the heatmap only where this component wins. np.where allocates a
        # new array, so the caller's `explanations` is left untouched.
        explanation = np.where(concept_per_pixel == i, explanations[i], 0)
        mask = np.uint8(mask * 255)
        mask = cv2.cvtColor(mask, cv2.COLOR_RGB2HSV)
        # Brightness (V channel) encodes the heatmap intensity.
        mask[:, :, 2] = np.uint8(255 * explanation)
        mask = cv2.cvtColor(mask, cv2.COLOR_HSV2RGB)
        mask = np.float32(mask) / 255
        masks.append(mask)

    mask = np.sum(np.float32(masks), axis=0)
    result = img * image_weight + mask * (1 - image_weight)
    result = np.uint8(result * 255)

    if concept_labels is not None:
        # Imported here because the notebook does not import Line2D elsewhere.
        from matplotlib.lines import Line2D

        px = 1 / plt.rcParams['figure.dpi']  # pixel in inches
        fig = plt.figure(figsize=(result.shape[1] * px, result.shape[0] * px))
        # The font size is passed to the legend directly instead of being
        # written into plt.rcParams, so no global plotting state is changed.
        legend_fontsize = int(
            14 * result.shape[0] / 256 / max(1, n_components / 6))
        lw = 5 * result.shape[0] / 256
        lines = [Line2D([0], [0], color=colors[i], lw=lw)
                 for i in range(n_components)]
        plt.legend(lines,
                   concept_labels,
                   mode="expand",
                   fancybox=True,
                   shadow=True,
                   fontsize=legend_fontsize)

        plt.tight_layout(pad=0, w_pad=0, h_pad=0)
        plt.axis('off')
        canvas = fig.canvas
        canvas.draw()
        if hasattr(canvas, "buffer_rgba"):
            # Preferred path: tostring_rgb() no longer exists in recent
            # Matplotlib releases. Drop the alpha channel to get RGB.
            data = np.asarray(canvas.buffer_rgba())[:, :, :3]
        else:
            data = np.frombuffer(canvas.tostring_rgb(), dtype=np.uint8)
            data = data.reshape(canvas.get_width_height()[::-1] + (3,))
        # Copy into contiguous memory before the figure is released.
        data = np.ascontiguousarray(data)
        plt.close(fig=fig)
        data = cv2.resize(data, (result.shape[1], result.shape[0]))
        # Legend panel is placed to the right of the blended image.
        result = np.hstack((result, data))
    return result

def dff(model, target_layer, computation_on_concepts, input_tensor, n_components):
    """Run DeepFeatureFactorization once and always hand back three values.

    :param model: Model passed on to DeepFeatureFactorization.
    :param target_layer: Layer passed on to DeepFeatureFactorization.
    :param computation_on_concepts: Passed on to DeepFeatureFactorization; when
                                    falsy the factorization yields only two
                                    values and the third result is None.
    :param input_tensor: Input forwarded to the factorization call.
    :param n_components: Number of components forwarded to the factorization call.
    :returns: Tuple (concepts, batch_explanations, concept_outputs).
    """
    factorizer = DeepFeatureFactorization(
        model=model,
        target_layer=target_layer,
        computation_on_concepts=computation_on_concepts,
    )
    outputs = factorizer(input_tensor, n_components)

    if not computation_on_concepts:
        # Two-value form: pad with None so callers can always unpack three.
        concepts, batch_explanations = outputs
        return concepts, batch_explanations, None

    concepts, batch_explanations, concept_outputs = outputs
    return concepts, batch_explanations, concept_outputs
        

def visualize_image(
    concepts,
    batch_explanation,
    concept_outputs,
    rgb_image_float,
    image_weight=0.3
):
    """Overlay the component heatmaps of one image on top of that image.

    :param concepts: Accepted for call-site symmetry; not used here.
    :param batch_explanation: Component heatmaps handed to
                              show_factorization_on_image.
    :param concept_outputs: Optional numpy array; when given, a softmax over
                            its last axis is computed (the result is not used
                            further because no legend is requested).
    :param rgb_image_float: Base image handed to show_factorization_on_image.
    :param image_weight: Blend weight of the base image.
    :returns: The single array produced by show_factorization_on_image.
    """
    if concept_outputs is not None:
        scores = torch.from_numpy(concept_outputs)
        concept_outputs = torch.softmax(scores, axis=-1).numpy()

    # The legend is deliberately disabled (concept_labels=None).
    return show_factorization_on_image(
        rgb_image_float,
        batch_explanation,
        image_weight=image_weight,
        concept_labels=None,
    )

In [None]:
images = []
rgb_img_floats = []
input_tensors = []
base_dir = "images"
dir_name = "containers"
n_components = 2
max_images = 7
resize = (395, 395)

image_dir = os.path.join(base_dir, dir_name)

# os.listdir() returns entries in arbitrary order. Sorting makes both the
# choice of the first `max_images` files and their index (used later in the
# output file names) reproducible between runs and machines.
for filename in sorted(os.listdir(image_dir))[:max_images]:
    print(filename)
    # get_image_from_fs is a project helper; judging by this unpacking it
    # yields the image, a float RGB copy and a model input tensor
    # (NOTE(review): confirm against utils.image_loader).
    img, rgb_img_float, input_tensor = get_image_from_fs(
        os.path.join(image_dir, filename),
        resize=resize,
    )
    images.append(img)
    rgb_img_floats.append(rgb_img_float)
    input_tensors.append(input_tensor)

# Stack the per-image tensors into one batch for the factorization.
input_tensor = torch.vstack(input_tensors)

In [None]:
# Run dff() on the stacked image batch, using the model's last dense block
# (model.features.denseblock4) as the target layer.
concepts, batch_explanations, concept_outputs = dff(
    model=model,
    target_layer=model.features.denseblock4,
    # With None, dff() returns concept_outputs=None. model.classifier is the
    # alternative that was tried here.
    computation_on_concepts=None,
    input_tensor=input_tensor,
    n_components=n_components
)

In [None]:
# Results of the factorization, keyed by the name used in the output file.
dff_data = {
    "concepts": concepts,
    "batch_explanations": batch_explanations,
    "concept_outputs": concept_outputs,
}

# One compressed .npz archive per result: dff_<dir_name>_<key>_<n_components>.npz
# NOTE(review): concept_outputs may be None here; NumPy then presumably stores
# a pickled object array that needs allow_pickle=True on load - confirm.
for _key, _array in dff_data.items():
    np.savez_compressed(f"dff_{dir_name}_{_key}_{n_components}.npz", _array)

In [None]:
imgs = []
visualizations = []

for i, batch_explanation in enumerate(batch_explanations):
    # visualize_image() returns ONE array (the blended visualization).
    # Unpacking it into two names would iterate over the image rows and
    # raise ValueError, so the result is bound to a single name.
    visualization = visualize_image(
        concepts=concepts,
        batch_explanation=batch_explanation,
        concept_outputs=concept_outputs,
        rgb_image_float=rgb_img_floats[i],
        image_weight=0.3,
    )
    # Keep the matching source image so both lists stay index-aligned.
    imgs.append(images[i])
    visualizations.append(visualization)

In [None]:
# Write one side-by-side figure (original | factorization overlay) per image.
output_dir = f"visualizations/{dir_name}"
# savefig() does not create missing directories; make sure the target exists.
os.makedirs(output_dir, exist_ok=True)

for i, (image, visualization) in enumerate(zip(images, visualizations)):
    fig2, ax2 = plt.subplots(1, 2, figsize=(8, 8))
    ax2[0].imshow(image)
    ax2[1].imshow(Image.fromarray(visualization))
    fig2.savefig(f"{output_dir}/{i}" + ".jpg", bbox_inches='tight', pad_inches=0)

## Inspection

In [None]:
def cosine_similarity(v1, v2, mode='deg'):
    """Return the angle between two vectors.

    Despite its name, this returns the *angle* whose cosine is the cosine
    similarity of ``v1`` and ``v2``: 0 for vectors pointing in the same
    direction, 90 degrees for orthogonal ones, 180 degrees for opposite ones.

    :param v1: First 1-D vector.
    :param v2: Second 1-D vector of the same length.
    :param mode: 'deg' to get the angle in degrees (default), 'rad' for radians.
    :returns: The angle in the requested unit. NaN if either vector has zero
              length, because the angle is undefined in that case.
    :raises ValueError: If ``mode`` is neither 'deg' nor 'rad'.
    """
    if mode not in ('deg', 'rad'):
        # Previously an unknown mode silently produced None.
        raise ValueError(f"mode must be 'deg' or 'rad', got {mode!r}")

    dot_product = np.dot(v1, v2)
    norm_product = np.linalg.norm(v1) * np.linalg.norm(v2)

    # Floating point rounding can push the ratio marginally outside [-1, 1]
    # (e.g. 1.0000000000000002 for identical vectors), which would make
    # arccos return NaN. Clipping keeps it inside the valid domain.
    cosine_angle = np.clip(dot_product / norm_product, -1.0, 1.0)
    angle_radians = np.arccos(cosine_angle)

    if mode == 'rad':
        return angle_radians
    return np.degrees(angle_radians)
    
def concept_similarity_matrix(concepts):
    """Pairwise angle matrix between the columns of ``concepts``.

    Every column of ``concepts`` is treated as one concept vector. Entry
    [i, j] holds cosine_similarity(column i, column j) with its default mode
    (an angle in degrees); the diagonal is left at 0.0.

    :param concepts: 2-D array whose columns are the concept vectors.
    :returns: Square float array with one row and column per concept.
    """
    n_concepts = concepts.shape[1]
    angles = np.zeros((n_concepts, n_concepts))
    for row, col in np.ndindex(n_concepts, n_concepts):
        if row != col:
            angles[row, col] = cosine_similarity(concepts[:, row],
                                                 concepts[:, col])
    return angles

In [None]:
m = concept_similarity_matrix(concepts)
sns.heatmap(m, annot=True, cmap='YlOrRd', cbar=False, linewidths=0.5, fmt=".2g")