In [None]:
%reload_ext autoreload
%autoreload 2

import torch
import torchvision.models as models
import torchvision.transforms as transforms
from torchvision.models.feature_extraction import create_feature_extractor, get_graph_node_names
from PIL import Image
import numpy as np
from sklearn.decomposition import NMF
from sklearn.manifold import TSNE
import matplotlib.pyplot as plt
import os

from backbones.resnet18 import ResNet18
from backbones.resnet50 import ResNet50
from backbones.vgg16 import VGG16
from backbones.vgg19 import VGG19
from backbones.mobilenetv3small import MobileNetV3Small
import utils.tensor as tensor_utils

In [None]:
def preprocess_image(image_path):
    """Load an image file and turn it into a normalized single-image batch.

    The image is converted to RGB, resized so its shorter side is 224 px,
    center-cropped to 224x224, converted to a tensor and normalized with the
    standard ImageNet mean/std used by torchvision models.

    Args:
        image_path: path to an image file readable by PIL.

    Returns:
        A float tensor of shape (1, 3, 224, 224).
    """
    # Convert to RGB: grayscale, palette or RGBA files would otherwise produce
    # 1 or 4 channels, which breaks the 3-channel Normalize below and the
    # torch.cat of the whole batch. The context manager closes the file handle
    # that PIL otherwise keeps open lazily.
    with Image.open(image_path) as raw_image:
        image = raw_image.convert("RGB")
    preprocess = transforms.Compose([
        transforms.Resize(224),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    ])
    input_tensor = preprocess(image)
    # Add a leading batch dimension so per-image results can be concatenated.
    input_batch = input_tensor.unsqueeze(0)
    return input_batch

# Load every image file in base_dir into a single batch (one entry per file).
base_dir = "images/flower_objects"
# Sorted because os.listdir order is arbitrary and later cells pick images by
# position (paths[i]); hidden files such as .DS_Store are skipped since PIL
# cannot open them. `paths` keeps bare file names relative to base_dir.
paths = sorted(
    f for f in os.listdir(base_dir)
    if not f.startswith(".") and os.path.isfile(os.path.join(base_dir, f))
)
if not paths:
    raise FileNotFoundError(f"No image files found in {base_dir}")
input_batches = torch.cat([preprocess_image(os.path.join(base_dir, path)) for path in paths])

In [None]:
# Size of the stacked input batch (first dimension = number of image files).
input_batches.size()

In [None]:
# Settings for concept extraction.
layer_num = 4  # layer index handed to model.get_features() below
n_components = 2  # number of NMF concepts to extract
model = VGG16()

In [None]:
# Get the model activations for each image.
# Inference only: no_grad avoids building an autograd graph, so the result can
# be turned into NumPy arrays (np.isnan, sklearn) by the cells below.
with torch.no_grad():
    activations = model.get_features(input_batches, layer_num)
batch_size, channels, h, w = activations.shape
print("ACT", activations.shape)

In [None]:
# Dimensions of the extracted layer activations.
activations.size()

In [None]:
# t-SNE of whole-image activation vectors: one point per input image.
flat_activations = activations.detach().cpu().reshape(activations.shape[0], -1).numpy()
# sklearn requires perplexity < n_samples; the default (30) fails for small
# image folders. A fixed random_state makes the embedding reproducible.
tsne = TSNE(
    n_components=2,
    perplexity=min(30.0, flat_activations.shape[0] - 1),
    random_state=0,
)
tsne_embeddings = tsne.fit_transform(flat_activations)
fig, ax = plt.subplots()
ax.scatter(tsne_embeddings[:, 0], tsne_embeddings[:, 1])
ax.set_title(f"t-SNE of layer {layer_num} activations (one point per image)");

In [None]:
# Channel-level NMF: every channel becomes one row whose columns are all
# (image, y, x) positions, so W is (channels x concepts) and H is
# (concepts x positions), which is reshaped into per-image concept maps.
# Work on a NumPy copy: permute() returns a view, so the NaN clean-up used to
# write back into `activations` and silently change earlier cells' data.
reshaped_activations = np.transpose(activations.detach().cpu().numpy(), (1, 0, 2, 3))
print("RE_ACT", reshaped_activations.shape)
# Eliminate NaN values (np.where allocates a new array, leaving `activations` intact)
reshaped_activations = np.where(np.isnan(reshaped_activations), 0, reshaped_activations)
reshaped_activations = reshaped_activations.reshape(reshaped_activations.shape[0], -1)
print("RE_ACT2", reshaped_activations.shape)
# NMF needs non-negative input: shift each channel so its minimum becomes 0,
# and add the shift back onto the concept weights afterwards.
offset = reshaped_activations.min(axis=1, keepdims=True)
print("OFFSET", offset, offset.shape)
reshaped_activations = reshaped_activations - offset

model_nmf = NMF(n_components=n_components, init='random', random_state=0)
W = model_nmf.fit_transform(reshaped_activations)
H = model_nmf.components_
print("W", W.shape, "offset", offset.shape)
concepts = W + offset
explanations = H.reshape(n_components, batch_size, h, w)
explanations = explanations.transpose((1, 0, 2, 3))
concepts.shape, explanations.shape

In [None]:
# Softmax over the spatial positions of every concept map so each map sums to 1.
# (Despite the name, `logits` holds these normalized probabilities.)
# Unpack into fresh names so the activation `channels` count is not clobbered.
n_images, n_concepts, h, w = explanations.shape
explanations_reshaped = explanations.reshape(n_images, n_concepts, -1)
logits = torch.nn.functional.softmax(torch.as_tensor(explanations_reshaped), dim=2)
logits = logits.reshape(n_images, n_concepts, h, w).numpy()
print(logits.shape)

# squeeze=False keeps `ax` 2-D even when there is a single concept.
fig, ax = plt.subplots(n_concepts, 2, squeeze=False)
# Index of the concept with the highest normalized value at each position of
# the first image; works for any number of concepts, not just two.
dominant = logits[0].argmax(axis=0)

for i in range(n_concepts):
    ax[i, 0].imshow(logits[0, i])
    ax[i, 0].set_title(f"concept {i}")
    ax[i, 1].imshow(dominant == i)
    ax[i, 1].set_title(f"concept {i} dominant")
fig.tight_layout()

In [None]:
# Per-image, channels-last feature maps for the position-level NMF below.
# `feature_maps` was never assigned anywhere in the notebook, so this cell
# raised NameError on a fresh kernel. Build it from the layer activations:
# one (1, H, W, C) tensor per image, in the same order as `paths`, with NaNs
# zeroed as in the channel-NMF cell.
# NOTE(review): assumes the intended maps are the layer-`layer_num` activations — confirm.
feature_maps = [
    torch.where(torch.isnan(a), torch.zeros_like(a), a).detach().cpu().permute(1, 2, 0).unsqueeze(0)
    for a in activations
]
feature_map_height = feature_maps[0].shape[1]
feature_map_width = feature_maps[0].shape[2]
# Height * width (not height**2) so non-square maps are handled.
feature_map_size = feature_map_height * feature_map_width
feature_map_dim = feature_maps[0].shape[3]

print("feature_map_height", feature_map_height, "feature_map_dim", feature_map_dim)

In [None]:
# Flatten each per-image map into a (positions, channels) NumPy matrix.
ndarrays = list(
    map(
        lambda fmap: fmap.detach().cpu().numpy().reshape(feature_map_size, feature_map_dim),
        feature_maps,
    )
)

In [None]:
# Stack all images' positions into one matrix: image i occupies rows
# [i * feature_map_size, (i + 1) * feature_map_size).
d = np.concatenate(ndarrays, axis=0)
np.shape(d)

In [None]:
# Position-level NMF: each row of `d` is one spatial position's channel vector.
# Exactly two components, because the distance/cluster cells below compare
# components 0 and 1. fit_transform returns the W found while fitting (the
# one `reconstruction_err_` refers to) instead of re-solving it via transform().
nmf = NMF(n_components=2, max_iter=200)
nmf_features = nmf.fit_transform(d)

In [None]:
# Shape of the per-position NMF codes: (positions across all images, components).
np.shape(nmf_features)

In [None]:
# Component matrix dimensions alongside the fit's reconstruction error.
(np.shape(nmf.components_), nmf.reconstruction_err_)

In [None]:
# Euclidean distance from every position vector to NMF components 0 and 1.
dist1, dist2 = (np.linalg.norm(d - nmf.components_[k], axis=1) for k in (0, 1))

In [None]:
# Hard assignment per position. Note the labelling: 1 means closer to
# component 0, and 0 means component 1 is at least as close.
clusters = np.where(dist1 < dist2, 1, 0)

In [None]:
# Show one input image next to its per-position cluster map.
i = 0
# `paths` holds bare file names relative to base_dir; opening paths[i] directly
# raised FileNotFoundError unless the working directory was base_dir.
with Image.open(os.path.join(base_dir, paths[i])) as raw_image:
    # Same resize/crop as preprocess_image so the photo lines up with the map.
    shown_image = transforms.CenterCrop(224)(transforms.Resize(224)(raw_image.convert("RGB")))
# Derive the width from the flattened size so non-square maps also reshape.
map_width = feature_map_size // feature_map_height
fig, ax = plt.subplots(1, 2, figsize=(8, 8))
ax[0].imshow(np.asarray(shown_image))
ax[0].set_title(paths[i])
ax[1].imshow(clusters[i * feature_map_size:(i + 1) * feature_map_size].reshape(feature_map_height, map_width))
ax[1].set_title("cluster per position");

In [None]:
# t-SNE of the per-position NMF codes (one point per spatial position per image).
# A fixed random_state makes the embedding reproducible across runs.
# NOTE(review): the row count is images x H x W, which can make t-SNE slow.
tsne = TSNE(n_components=2, random_state=0)
embeddings = tsne.fit_transform(nmf_features)

In [None]:
# Scatter the 2-D t-SNE coordinates (column 0 as x, column 1 as y).
plt.scatter(*embeddings[:, :2].T)