# Feature Factorization

In [None]:
%load_ext autoreload
%autoreload 2

import glob
import torch
import numpy as np
import matplotlib.pyplot as plt
from utils.visualize import visualize_image
from utils.image_loader import get_image_from_fs
from utils.factorization import compute_dff, scale_explanations

In [None]:
# Run configuration: where the images live and which saved feature map to load.
dir_name = "images/containers"
model_name = "keypoint_rcnn"
feature_key = "f4"  # selects the "<model_name>_<feature_key>.npz" feature file

# NOTE(review): axis order of this shape is not established here; it is consumed
# below as width=image_shape[1], height=image_shape[2] — confirm the convention.
image_shape = (3, 1_100, 600)

### Load image features

Load the precomputed image features from an `.npz` file.

In [None]:
# Load the saved feature array. For .npz archives np.load returns an NpzFile that
# keeps the underlying zip file open, so use it as a context manager to close it.
# NOTE(review): allow_pickle=True unpickles object arrays, which is unsafe for
# untrusted files; keep it only if image_features really is an object array.
with np.load(f"{dir_name}/{model_name}_{feature_key}.npz", allow_pickle=True) as npz_file:
    # Indexing reads the array fully into memory, so it remains valid after closing.
    features = npz_file["image_features"]
print("Loaded image features", features.shape)

### Compute concepts and explanations

Factorize the image features into a small number of concepts (here `n_components=3`) and their corresponding explanations.

In [None]:
# Factorize the features into 3 components: the concepts plus their explanations.
concepts, explanations = compute_dff(
    features,
    n_components=3,
)
print(
    "Concepts", concepts.shape,
    "Explanations", explanations.shape,
)

### Scale explanations

Scale the explanations so that they match the image size.

In [None]:
# Scale the explanation maps to the target image size.
# NOTE(review): image_shape[1] is passed as width and image_shape[2] as height;
# if image_shape is (C, H, W) these look swapped — confirm against scale_explanations.
scaled_explanations = scale_explanations(
    explanations,
    height=image_shape[2],
    width=image_shape[1],
)

### Visualize explanations

Visualize the explanations of the first image on top of the image itself.

In [None]:
# Collect candidate images: all JPEGs first, then all PNGs, each in glob's order.
# NOTE(review): glob order is filesystem-dependent; pairing filenames[0] with
# scaled_explanations[0] assumes features were extracted in this same order — confirm.
filenames = glob.glob(dir_name + "/*.jpg")
filenames.extend(glob.glob(dir_name + "/*.png"))
if not filenames:
    # Fail with a clear message instead of an opaque IndexError on filenames[0].
    raise FileNotFoundError(f"No .jpg or .png images found in {dir_name!r}")

img, rgb_img_float, input_tensor = get_image_from_fs(
    filenames[0],
    resize=None,
)

# Render the first image's explanations on top of the float RGB image.
# NOTE(review): the meaning of the positional None argument is defined by
# visualize_image, which is not visible here — document it once confirmed.
visualizations = visualize_image(
    concepts,
    scaled_explanations[0],
    None,
    rgb_img_float,
    image_weight=0.3
)

In [None]:
# Single-panel figure showing the visualization without axis ticks or frame.
fig, ax = plt.subplots(nrows=1, ncols=1, figsize=(10, 12))
ax.set_axis_off()
ax.imshow(visualizations)

In [None]:
# Save the figure cropped tightly to its content, with no padding.
fig.savefig("out.png", pad_inches=0, bbox_inches="tight")