In [1]:
import os
import pickle
from glob import glob

In [2]:
import numpy as np
import plotly.graph_objects as go
from plotly.subplots import make_subplots
from umap import UMAP

In [3]:
import sys
sys.path.append("../src")

In [4]:
from utils.plots import input2image
from utils import plots

In [5]:
# Data locations for this notebook.
# Earlier value of rf_root: "/mnt/nas3/lab_member_directories/2021_genta/resnet/e_receptive_field/"
# NOTE(review): neither name is passed to Visualizer in the cells below, and
# Visualizer's own rf_root default is the nas3 path above -- confirm which copy is intended.
analysis_root, rf_root = (
    "/data2/genta/resnet/analysis/",
    "/mnt/nas5/lab_member_directories/2021_genta/resnet/e_receptive_field/",
)

In [6]:
from sklearn.decomposition import PCA
from sklearn.manifold import TSNE

In [9]:
from IPython.display import display

class Visualizer:
    """Interactive plotly viewer for one channel of a trained network.

    The left panel is a 2-D embedding (PCA / UMAP / t-SNE) of the channel's
    ``act_preact[1]`` data. Clicking a point highlights it and shows the
    matching receptive-field image (middle) and the gradient-weighted
    receptive-field image (right).

    Usage, as in the driver cells of this notebook::

        view = Visualizer(model_name="resnet34", key_layer="layer4", block_id=2, dr_method_name="UMAP")
        view.set_ch_data(38)
        view.set_scatter_data(n_neighbors=15)
        view.view_figure()
    """

    # Roots used when the constructor receives None.
    # NOTE(review): the notebook's config cell assigns a nas5 rf_root, but this default
    # still points at nas3 and the driver cells never pass rf_root -- confirm which copy is intended.
    DEFAULT_ANALYSIS_ROOT = "/data2/genta/resnet/analysis/"
    DEFAULT_RF_ROOT = "/mnt/nas3/lab_member_directories/2021_genta/resnet/e_receptive_field/"
    # (height, width) used to unflatten rows when set_scatter_data receives `specific_pos`.
    DEFAULT_RF_SHAPE = (5, 5)

    def __init__(self, analysis_root=None, rf_root=None, model_name=None, key_layer=None, block_id=None, dr_method_name="PCA"):
        """
        Args:
            analysis_root: directory holding ``<model>*/<layer>*<block>`` activation folders;
                None -> DEFAULT_ANALYSIS_ROOT.
            rf_root: directory holding ``*<model>*<layer>.<block>*`` receptive-field folders;
                None -> DEFAULT_RF_ROOT.
            model_name, key_layer, block_id: when all three are given, the layer folders are
                resolved immediately via set_vis_layer (which may raise AssertionError).
            dr_method_name: "PCA", "UMAP" or "TSNE" (case-insensitive).
        """
        self.analysis_root = self.DEFAULT_ANALYSIS_ROOT if analysis_root is None else analysis_root
        self.rf_root = self.DEFAULT_RF_ROOT if rf_root is None else rf_root

        # Filled in by set_vis_layer / set_ch_data / set_scatter_data respectively.
        self.act_dir_path = None
        self.rf_dir_path = None
        self.act_preact = None
        self.rfimgs = None
        self.rfgrads = None
        self.dr_model = None
        self.f = None

        self.dr_method_name = dr_method_name
        if model_name is not None and key_layer is not None and block_id is not None:
            self.set_vis_layer(model_name, key_layer, block_id)

    def _glob_act_paths(self, model_name, key_layer, block_id):
        """Candidate activation directories under analysis_root for the given layer."""
        return glob(os.path.join(self.analysis_root, model_name + "*", key_layer + "*" + str(block_id)))

    def _glob_rf_paths(self, model_name, key_layer, block_id):
        """Candidate receptive-field directories under rf_root for the given layer."""
        return glob(os.path.join(self.rf_root, "*" + model_name + "*" + key_layer + "." + str(block_id) + "*"))

    def _show_glob_paths(self, model_name, key_layer, block_id):
        """Print every directory the set_vis_layer patterns match (to debug 'ambiguous' errors)."""
        print("analysis path")
        print(self._glob_act_paths(model_name, key_layer, block_id))
        print("rf path")
        print(self._glob_rf_paths(model_name, key_layer, block_id))

    def set_vis_layer(self, model_name, key_layer, block_id):
        """Resolve ``act_dir_path`` and ``rf_dir_path`` for one layer; each pattern must match exactly once.

        Raises:
            ValueError: analysis_root or rf_root is None.
            AssertionError: a pattern matches no directory or more than one.
        """
        if self.analysis_root is None or self.rf_root is None:
            raise ValueError("analysis_root and rf_root must both be set")

        cond_path = self._glob_act_paths(model_name, key_layer, block_id)
        assert len(cond_path) != 0, "Not found"
        assert len(cond_path) == 1, "ambiguous; model name or key_layer or block_id. you can check paths to use _show_glob_paths"
        self.act_dir_path = cond_path[0]

        cond_path = self._glob_rf_paths(model_name, key_layer, block_id)
        assert len(cond_path) != 0, "Not found"
        assert len(cond_path) == 1, "ambiguous; model name or key_layer or block_id. you can check paths to use _show_glob_paths"
        self.rf_dir_path = cond_path[0]

    def dimension_reduction_methods(self, n_neighbors=None, **kwargs):
        """Instantiate ``self.dr_model`` according to ``self.dr_method_name``.

        Args:
            n_neighbors: UMAP neighbourhood size. None lets UMAP use its own default
                (forwarding None makes UMAP fail at fit time). Ignored by PCA/TSNE.
            **kwargs: forwarded to UMAP and TSNE. PCA ignores them and keeps all
                components (svd_solver="full"); set_scatter_data plots only the first two.

        Raises:
            ValueError: the upper-cased method name is not PCA, UMAP or TSNE.
        """
        tmp_name = self.dr_method_name.upper()
        if tmp_name == "PCA":
            self.dr_model = PCA(svd_solver="full")
        elif tmp_name == "UMAP":
            umap_kwargs = dict(kwargs)
            if n_neighbors is not None:
                umap_kwargs["n_neighbors"] = n_neighbors
            self.dr_model = UMAP(**umap_kwargs)
        elif tmp_name == "TSNE":
            self.dr_model = TSNE(**kwargs)
        else:
            raise ValueError(tmp_name)

    def set_ch_data(self, ch):
        """Load the data of channel ``ch`` for the current layer.

        Reads ``act_preact_{ch:03}.pkl`` from act_dir_path into ``self.act_preact``, and
        ``rfimgs-{ch:03}.npy`` / ``rfgrads-{ch:03}.npy`` from ``rf_dir_path/top_rf_datas``
        into ``self.rfimgs`` / ``self.rfgrads``.

        Raises:
            ValueError: set_vis_layer has not resolved the layer folders yet.
        """
        if self.act_dir_path is None or self.rf_dir_path is None:
            msg = "do set_vis_layer"
            raise ValueError(msg)

        # pickle.load can execute arbitrary code: only use trusted, lab-generated files here.
        with open(os.path.join(self.act_dir_path, "act_preact_{:03}.pkl".format(ch)), "rb") as f:
            self.act_preact = pickle.load(f)

        top_dir = os.path.join(self.rf_dir_path, "top_rf_datas")
        self.rfimgs = np.load(os.path.join(top_dir, "rfimgs-{:03}.npy".format(ch)))
        self.rfgrads = np.load(os.path.join(top_dir, "rfgrads-{:03}.npy".format(ch)))

    def _require_ch_data(self):
        """Raise ValueError unless set_ch_data has loaded the activation and RF arrays."""
        if any(getattr(self, name, None) is None for name in ("act_preact", "rfimgs", "rfgrads")):
            raise ValueError("do set_ch_data")

    def _embedding_input(self, specific_pos, rf_shape):
        """Rows fed to the dimensionality reduction: all of act_preact[1], or one spatial position."""
        data = self.act_preact[1]
        if specific_pos is None:
            return data
        rf_h, rf_w = self.DEFAULT_RF_SHAPE if rf_shape is None else rf_shape
        # NOTE(review): assumes each row flattens a (-1, rf_h, rf_w) map -- TODO confirm.
        return data.reshape(len(data), -1, rf_h, rf_w)[..., specific_pos[0], specific_pos[1]]

    def _erf_inputs(self, N):
        """Top-N receptive-field inputs weighted by the absolute normalized input-space gradient."""
        return self.rfimgs[:N] * np.abs(plots.normalize_inputspace(self.rfgrads[:N]))

    @staticmethod
    def _to_rgb255(inputs):
        """Convert a batch of model inputs to channel-last pixel arrays for go.Image."""
        images = input2image(inputs)
        # Presumably input2image returns NCHW floats in [0, 1] -- TODO confirm.
        # go.Image expects 0-255, so scale by 255 (256 would map white out of range).
        return np.transpose(images, (0, 2, 3, 1)) * 255

    def set_scatter_data(self, n_neighbors=None, random_state=1119, N=100, dr_method_name=None, specific_pos=None, rf_shape=None):
        """Embed the current channel's data in 2-D and build the interactive figure in ``self.f``.

        Args:
            n_neighbors: UMAP neighbourhood size (None -> UMAP default); unused by PCA/TSNE.
            random_state: seed forwarded to UMAP / TSNE.
            N: number of top receptive-field images (and marker colors) to use.
            dr_method_name: if given, replaces self.dr_method_name first.
            specific_pos: optional (row, col). When given, embed only that spatial position
                of each row after reshaping it to (len, -1, *rf_shape).
            rf_shape: (height, width) for that reshape; None -> DEFAULT_RF_SHAPE.

        Raises:
            ValueError: set_ch_data has not been called, or the method name is unknown.
        """
        if dr_method_name is not None:
            self.dr_method_name = dr_method_name
        self._require_ch_data()
        self.dimension_reduction_methods(n_components=2, n_neighbors=n_neighbors, random_state=random_state)
        self.scatter_data = self.dr_model.fit_transform(self._embedding_input(specific_pos, rf_shape))

        self.images = self._to_rgb255(self.rfimgs[:N])
        self.erfimages = self._to_rgb255(self._erf_inputs(N))
        self._build_figure(N)

    def _build_figure(self, N):
        """Create the 1x3 FigureWidget (scatter | RF image | ERF image) and wire the click callback."""
        fig = go.FigureWidget(make_subplots(rows=1, cols=3))
        fig.add_trace(go.Scatter(x=self.scatter_data[:, 0], y=self.scatter_data[:, 1], mode='markers'), row=1, col=1)
        # Blank placeholder images until a point is clicked.
        fig.add_trace(go.Image(z=np.ones_like(self.images[0])), row=1, col=2)
        fig.add_trace(go.Image(z=np.ones_like(self.images[0])), row=1, col=3)

        # get_colors is assumed to yield N RGBA tuples with components in [0, 1] -- TODO confirm.
        hex_colors = ["#{:02x}{:02x}{:02x}".format(int(r * 255), int(g * 255), int(b * 255))
                      for r, g, b, _ in plots.get_colors(N=N)]
        default_colors = tuple(hex_colors[::-1])
        default_size = (10, ) * N
        # NOTE(review): colors, sizes and images cover only the first N rows; this assumes
        # the embedded data has N rows as well -- confirm against act_preact.

        scatter, f_images, f_erfimages = fig.data
        scatter.marker.color = default_colors
        scatter.marker.size = default_size
        fig.layout.hovermode = 'closest'

        # Bind this figure's own arrays so a later set_scatter_data call (which replaces
        # self.images / self.scatter) cannot redirect clicks made in an older figure.
        images, erfimages = self.images, self.erfimages

        def update_point(trace, points, selector):
            # Guard against callbacks delivered with an empty point_inds list.
            if not points.point_inds:
                return
            i = points.point_inds[0]
            colors = list(default_colors)
            marker_size = list(default_size)
            colors[i] = "red"
            marker_size[i] = 20
            # Push all four property changes to the front end in one update.
            with fig.batch_update():
                scatter.marker.color = colors
                scatter.marker.size = marker_size
                f_images.z = images[i]
                f_erfimages.z = erfimages[i]

        scatter.on_click(update_point)
        self.f = fig
        self.scatter = scatter

    def view_figure(self):
        """Display the figure built by set_scatter_data.

        Raises:
            ValueError: no figure has been built yet.
        """
        if self.f is None:
            msg = "do set_ch_data & set_scatter_data"
            raise ValueError(msg)

        display(self.f)

    def show_all(self, mode="rfimgs", N=100):
        """Plot the top-N images in a grid.

        Args:
            mode: "rfimgs" for the raw receptive-field images, "erfimgs" for the
                gradient-weighted ones.
            N: number of images to show.

        Raises:
            ValueError: ``mode`` is neither "rfimgs" nor "erfimgs".
        """
        if mode == "rfimgs":
            inputs = self.rfimgs[:N]
        elif mode == "erfimgs":
            inputs = self._erf_inputs(N)
        else:
            raise ValueError("unknown mode: {!r} (expected 'rfimgs' or 'erfimgs')".format(mode))
        plots.plot_imshows(input2image(inputs), show_flag=True)

In [10]:
# ResNet-34 layer4.2: UMAP of each channel's data at the single spatial position (2, 2).
# Alternative channel lists tried: [206, 428, 212, 467, 435, 140, 461, 284, 459, 242] and [206].
ch_list = [206, 428, 212, 467]
viewers = [Visualizer(model_name="resnet34", key_layer="layer4", block_id=2, dr_method_name="UMAP") for _ in ch_list]
for viewer, ch in zip(viewers, ch_list):
    viewer.set_ch_data(ch)
    viewer.set_scatter_data(n_neighbors=10, specific_pos=(2, 2))
    viewer.view_figure()

FigureWidget({
    'data': [{'marker': {'color': [#fde724, #f8e621, #f1e51c, #ece41a, #e4e318,
               …

FigureWidget({
    'data': [{'marker': {'color': [#fde724, #f8e621, #f1e51c, #ece41a, #e4e318,
               …

FigureWidget({
    'data': [{'marker': {'color': [#fde724, #f8e621, #f1e51c, #ece41a, #e4e318,
               …

FigureWidget({
    'data': [{'marker': {'color': [#fde724, #f8e621, #f1e51c, #ece41a, #e4e318,
               …

In [10]:
# NOTE(review): this cell repeats the earlier In [10] cell verbatim (same channels,
# same specific_pos=(2, 2)) and rebuilds `viewers`; consider keeping only one copy.
# Alternative channel lists tried: [206, 428, 212, 467, 435, 140, 461, 284, 459, 242] and [206].
ch_list = [206, 428, 212, 467]
viewers = [
    Visualizer(model_name="resnet34", key_layer="layer4", block_id=2, dr_method_name="UMAP")
    for _ in range(len(ch_list))
]
for position, channel in enumerate(ch_list):
    current = viewers[position]
    current.set_ch_data(channel)
    current.set_scatter_data(n_neighbors=10, specific_pos=(2, 2))
    current.view_figure()

FigureWidget({
    'data': [{'marker': {'color': [#fde724, #f8e621, #f1e51c, #ece41a, #e4e318,
               …

FigureWidget({
    'data': [{'marker': {'color': [#fde724, #f8e621, #f1e51c, #ece41a, #e4e318,
               …

FigureWidget({
    'data': [{'marker': {'color': [#fde724, #f8e621, #f1e51c, #ece41a, #e4e318,
               …

FigureWidget({
    'data': [{'marker': {'color': [#fde724, #f8e621, #f1e51c, #ece41a, #e4e318,
               …

In [12]:
# ch_list = [466,  37,  89, 104, 441, 421, 159]
ch_list = [206, 428, 212, 467]
viewers2 = [Visualizer(model_name="resnet34", key_layer="layer4", block_id=2, dr_method_name="UMAP") for _ in ch_list]
for cnt, ch in enumerate(ch_list):
    viewers2[cnt].set_ch_data(ch)
    viewers2[cnt].set_scatter_data(n_neighbors=10)
    viewers2[cnt].view_figure()

FigureWidget({
    'data': [{'marker': {'color': [#fde724, #f8e621, #f1e51c, #ece41a, #e4e318,
               …

FigureWidget({
    'data': [{'marker': {'color': [#fde724, #f8e621, #f1e51c, #ece41a, #e4e318,
               …

FigureWidget({
    'data': [{'marker': {'color': [#fde724, #f8e621, #f1e51c, #ece41a, #e4e318,
               …

FigureWidget({
    'data': [{'marker': {'color': [#fde724, #f8e621, #f1e51c, #ece41a, #e4e318,
               …

In [14]:
# PlainNet-34 layer4.2 (no skip connections): UMAP for channel 466.
ch_list = [466]
viewers3 = [Visualizer(model_name="plainnet34", key_layer="layer4", block_id=2, dr_method_name="UMAP") for _ in ch_list]
for viewer, ch in zip(viewers3, ch_list):
    viewer.set_ch_data(ch)
    viewer.set_scatter_data(n_neighbors=10)
    viewer.view_figure()

FigureWidget({
    'data': [{'marker': {'color': [#fde724, #f8e621, #f1e51c, #ece41a, #e4e318,
               …

In [17]:
# ResNet-34 layer1.2, channel 62: UMAP embedding with n_neighbors=10.
resnet34_view = Visualizer(
    model_name="resnet34",
    key_layer="layer1",
    block_id=2,
    dr_method_name="UMAP",
)
resnet34_view.set_ch_data(62)
resnet34_view.set_scatter_data(n_neighbors=10)
resnet34_view.view_figure()

FigureWidget({
    'data': [{'marker': {'color': [#fde724, #f8e621, #f1e51c, #ece41a, #e4e318,
               …

# PCA (the layer4.0 cells for channels 147/148 are also shown with UMAP for comparison)

In [27]:
# ResNet-34 layer4.0, channel 147: PCA embedding.
resnet34_view = Visualizer(
    model_name="resnet34",
    key_layer="layer4",
    block_id=0,
    dr_method_name="PCA",
)
resnet34_view.set_ch_data(147)
resnet34_view.set_scatter_data()
resnet34_view.view_figure()

FigureWidget({
    'data': [{'marker': {'color': [#fde724, #f8e621, #f1e51c, #ece41a, #e4e318,
               …

In [30]:
# ResNet-34 layer4.0, channel 147: UMAP embedding (n_neighbors=15) for comparison with PCA.
resnet34_view = Visualizer(
    model_name="resnet34",
    key_layer="layer4",
    block_id=0,
    dr_method_name="UMAP",
)
resnet34_view.set_ch_data(147)
resnet34_view.set_scatter_data(n_neighbors=15)
resnet34_view.view_figure()

FigureWidget({
    'data': [{'marker': {'color': [#fde724, #f8e621, #f1e51c, #ece41a, #e4e318,
               …

In [31]:
# ResNet-34 layer4.0, channel 148: UMAP embedding (n_neighbors=15).
resnet34_view = Visualizer(
    model_name="resnet34",
    key_layer="layer4",
    block_id=0,
    dr_method_name="UMAP",
)
resnet34_view.set_ch_data(148)
resnet34_view.set_scatter_data(n_neighbors=15)
resnet34_view.view_figure()

FigureWidget({
    'data': [{'marker': {'color': [#fde724, #f8e621, #f1e51c, #ece41a, #e4e318,
               …

In [11]:
# ResNet-34 layer4.2, channel 38: PCA embedding.
resnet34_view = Visualizer(
    model_name="resnet34",
    key_layer="layer4",
    block_id=2,
    dr_method_name="PCA",
)
resnet34_view.set_ch_data(38)
resnet34_view.set_scatter_data()
resnet34_view.view_figure()

FigureWidget({
    'data': [{'marker': {'color': [#fde724, #f8e621, #f1e51c, #ece41a, #e4e318,
               …

In [10]:
# PlainNet-34 layer4.2, channel 489: PCA embedding.
plainnet34_view = Visualizer(
    model_name="plainnet34",
    key_layer="layer4",
    block_id=2,
    dr_method_name="PCA",
)
plainnet34_view.set_ch_data(489)
plainnet34_view.set_scatter_data()
plainnet34_view.view_figure()

FigureWidget({
    'data': [{'marker': {'color': [#fde724, #f8e621, #f1e51c, #ece41a, #e4e318,
               …

# UMAP (the final plainnet34 layer4.1 cell uses PCA)

In [25]:
# ResNet-34 layer4.2, channel 38: UMAP embedding (n_neighbors=15).
resnet34_view = Visualizer(
    model_name="resnet34",
    key_layer="layer4",
    block_id=2,
    dr_method_name="UMAP",
)
resnet34_view.set_ch_data(38)
resnet34_view.set_scatter_data(n_neighbors=15)
resnet34_view.view_figure()

FigureWidget({
    'data': [{'marker': {'color': [#fde724, #f8e621, #f1e51c, #ece41a, #e4e318,
               …

In [26]:
# PlainNet-34 layer4.2, channel 38: UMAP embedding (n_neighbors=15), same channel as the ResNet cell above.
plainnet34_view = Visualizer(
    model_name="plainnet34",
    key_layer="layer4",
    block_id=2,
    dr_method_name="UMAP",
)
plainnet34_view.set_ch_data(38)
plainnet34_view.set_scatter_data(n_neighbors=15)
plainnet34_view.view_figure()

FigureWidget({
    'data': [{'marker': {'color': [#fde724, #f8e621, #f1e51c, #ece41a, #e4e318,
               …

In [19]:
# ResNet-34 layer4.2, channel 358: UMAP embedding (n_neighbors=15).
resnet34_view = Visualizer(
    model_name="resnet34",
    key_layer="layer4",
    block_id=2,
    dr_method_name="UMAP",
)
resnet34_view.set_ch_data(358)
resnet34_view.set_scatter_data(n_neighbors=15)
resnet34_view.view_figure()

FigureWidget({
    'data': [{'marker': {'color': [#fde724, #f8e621, #f1e51c, #ece41a, #e4e318,
               …

In [20]:
# PlainNet-34 layer4.2, channel 489: UMAP embedding (n_neighbors=15).
plainnet34_view = Visualizer(
    model_name="plainnet34",
    key_layer="layer4",
    block_id=2,
    dr_method_name="UMAP",
)
plainnet34_view.set_ch_data(489)
plainnet34_view.set_scatter_data(n_neighbors=15)
plainnet34_view.view_figure()

FigureWidget({
    'data': [{'marker': {'color': [#fde724, #f8e621, #f1e51c, #ece41a, #e4e318,
               …

In [10]:
# ResNet-34 layer4.2, channel 358 again, now with a smaller UMAP neighbourhood (n_neighbors=10).
resnet34_view = Visualizer(
    model_name="resnet34",
    key_layer="layer4",
    block_id=2,
    dr_method_name="UMAP",
)
resnet34_view.set_ch_data(358)
resnet34_view.set_scatter_data(n_neighbors=10)
resnet34_view.view_figure()

FigureWidget({
    'data': [{'marker': {'color': [#fde724, #f8e621, #f1e51c, #ece41a, #e4e318,
               …

In [9]:
# PlainNet-34 layer4.1, channel 232: PCA embedding.
plainnet34_view = Visualizer(
    model_name="plainnet34",
    key_layer="layer4",
    block_id=1,
    dr_method_name="PCA",
)
plainnet34_view.set_ch_data(232)
plainnet34_view.set_scatter_data()
plainnet34_view.view_figure()

FigureWidget({
    'data': [{'marker': {'color': [#fde724, #f8e621, #f1e51c, #ece41a, #e4e318,
               …